# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MindVision classification training script."""

from mindspore import context, load_checkpoint, load_param_into_net
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import init, get_rank, get_group_size

from mindvision.common.check_param import Validator, Rel
from mindvision.classification.utils import parse_args
from mindvision.common.utils.config import Config
from mindvision.classification.dataset.base_dataset import create_dataset
from mindvision.classification.models.build_train import build_model
from mindvision.classification.models.create_loss import create_loss
from mindvision.classification.models.optimizer import create_optimizer


def main(pargs):
    """Train a classification network as described by a configuration.

    Builds the configuration from ``pargs.config``, sets the MindSpore
    execution context (plus the data-parallel context when
    ``config.run_distribute`` is set), creates the dataset, network, loss and
    optimizer, optionally loads pretrained parameters, and runs
    ``Model.train`` with checkpoint-saving and loss-monitoring callbacks.

    Args:
        pargs: Parsed command-line arguments. Only ``pargs.config`` is read
            in this function; it is handed to ``Config`` unchanged.
    """
    # Imported locally so this function does not depend on edits to the
    # file-level import block.
    import os

    # set config context
    config = Config(pargs.config)
    context.set_context(mode=config.mode,
                        device_target=config.device_target,
                        save_graphs=False)

    # run distribute
    if config.run_distribute:
        init()
        context.set_auto_parallel_context(device_num=get_group_size(),
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
        # Each rank writes into its own sub-directory. os.path.join inserts
        # the separator when config.ckpt_path lacks a trailing one; plain
        # string concatenation would produce e.g. "./outputckpt_0/".
        ckpt_save_dir = os.path.join(config.ckpt_path, f"ckpt_{get_rank()}")
    else:
        ckpt_save_dir = config.ckpt_path

    # prepare dataset
    dataset_train = create_dataset(config)
    # Query the size once; it is both validated (> 0) and reused as the
    # number of batches per epoch passed to the optimizer factory.
    batches_per_epoch = dataset_train.get_dataset_size()
    Validator.check_int(batches_per_epoch, 0, Rel.GT)

    # set network
    network = build_model(config)

    # set loss, optimizer
    network_loss = create_loss(config)
    network_opt = create_optimizer(network.trainable_params(), config, batches_per_epoch)

    # NOTE(review): the flag is read from config.TRAIN while the checkpoint
    # path is read from the top level of config -- confirm against the
    # config schema that both keys exist where they are looked up.
    if config.TRAIN.pre_trained:
        # load pretrain model
        param_dict = load_checkpoint(config.pretrained_model)
        load_param_into_net(network, param_dict)

    # set checkpoint for the network
    ckpt_config = CheckpointConfig(
        save_checkpoint_steps=config.save_checkpoint_steps,
        keep_checkpoint_max=config.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix=config.model_name,
                                    directory=ckpt_save_dir,
                                    config=ckpt_config)

    # init the whole Model
    model = Model(network,
                  network_loss,
                  network_opt,
                  metrics={"Accuracy": Accuracy()})

    # begin to train
    print(f'[Start training `{config.model_name}`]')
    print("=" * 80)
    model.train(config.epochs,
                dataset_train,
                callbacks=[ckpt_callback, LossMonitor()],
                dataset_sink_mode=config.dataset_sink_mode)
    print(f'[End of training `{config.model_name}`]')


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the result to main().
    main(parse_args())
